{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0a1f1524-f670-4746-b9b9-ffba21db6c74",
   "metadata": {},
   "source": [
    "# Baseline solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1747976c-8331-4119-993b-48eec03be5b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic (not a shell call) so the packages are installed into the\n",
    "# environment of the kernel that is running this notebook.\n",
    "%pip install emoji transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61cd0956-563f-4daf-ab5b-1e4bd1c91306",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8da6b19-51fb-4c1c-8df7-0ad579126ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fb303da-7dd3-4777-ad20-2220f802a40b",
   "metadata": {},
   "source": [
    "## Read"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a29e40a-d19c-44a9-a55c-dcc9bfbc66d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "data_directory_path = Path(\"../data/\")\n",
    "data_path = data_directory_path / \"train_dataset_train.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa74446e-e062-44d9-a17c-0343954c6162",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e397766-eca0-4dbd-a42a-b5bc2b259439",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(data_path, sep=';')#.tail(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d87c1fe-bf02-4d4b-b075-1a4f0b3d7c89",
   "metadata": {},
   "outputs": [],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ac4bb13-72bc-490e-b3cf-95ed06371420",
   "metadata": {},
   "source": [
    "Все тексты инцидентов начинаются с \"'\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd1519c2-2c0f-4791-85f5-ed622bf8cd9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vectorised prefix check. Unlike indexing x[0], it does not raise on empty or\n",
    "# missing texts: those simply count as \"does not start with an apostrophe\".\n",
    "data[\"Текст инцидента\"].str.startswith(\"'\", na=False).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6857280c-f8f0-45ba-a312-a465c1f5badf",
   "metadata": {},
   "source": [
    "Каждой теме соответствует только одна группа тем:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71643b34-1d70-4e35-b96d-5cb2cbf6e9f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "(data.groupby([\"Тема\"])[[\"Группа тем\"]].nunique() > 1).any()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cea9ed6-1d22-4e93-93d8-710f7b6dd0ae",
   "metadata": {},
   "source": [
    "Нет тем, которые со звёздочкой и без одновременно:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0793864-3c32-4d63-b4e3-5ec10618a0bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "themes = set(data[\"Тема\"].unique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46b773d3-ce5b-4f1f-abbf-d606ab1bb515",
   "metadata": {},
   "outputs": [],
   "source": [
    "themes.intersection(set(x[2:] for x in data[\"Тема\"].unique()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4300533-3900-4514-a64e-410d8a75e468",
   "metadata": {},
   "source": [
    "## Preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77267299-f595-466c-bcd5-da242a134fc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "433ce6f4-3126-43fd-9d8c-c8a11311479e",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = preprocess.preprocess_data(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63cac6b5-eea5-4bc4-b3c1-88db97fd3b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data[\"Текст инцидента\"][data[\"Текст инцидента\"].str.contains(\"в шоке\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58140a3e-8941-44aa-bcf4-439cf3c1129b",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d47e30f-18df-4e81-adc9-fafaa2a423c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(x):\n",
    "    \"\"\"Tokenize `x` with the notebook-level `tokenizer` and move the result to `device`.\n",
    "\n",
    "    Both `tokenizer` and `device` are looked up at call time, so the tokenizer\n",
    "    only has to exist before the first call, not before this definition.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(x, padding=True, truncation='only_first', return_tensors=\"pt\")\n",
    "    return encoded.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd3796ce-7249-4cab-bb2f-0cb2f631459c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.dataset import prepare_data_for_dataset, MessagesDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "740de1b7-07ff-4134-8562-83a3fd5c5b39",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "603e5cc9-548a-4d8b-9f24-fa8678a0b8a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "executor_encoder = LabelEncoder().fit(data[\"Исполнитель\"])\n",
    "group_encoder    = LabelEncoder().fit(data[\"Группа тем\"])\n",
    "theme_encoder    = LabelEncoder().fit(data[\"Тема\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25cc34ab-3c52-498a-b527-993f7e47555c",
   "metadata": {},
   "outputs": [],
   "source": [
    "data[\"Исполнитель\"] = executor_encoder.transform(data[\"Исполнитель\"])\n",
    "data[\"Группа тем\"]  = group_encoder.transform(data[\"Группа тем\"])\n",
    "data[\"Тема\"]        = theme_encoder.transform(data[\"Тема\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d69515d-3bbe-4dec-a3c1-0a888c021630",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4090617-032b-403f-bd83-43a2223432b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_executor, test_executor, train_group, test_group, train_message, test_message, train_theme, test_theme = train_test_split(\n",
    "    *prepare_data_for_dataset(data, device),\n",
    "    test_size=0.25,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "903a0823-1bc0-4bb0-b6e8-e84a72598d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = MessagesDataset(train_message, train_executor, train_theme, train_group)\n",
    "test_dataset  = MessagesDataset(test_message, test_executor, test_theme, test_group)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "673c59ba-e7bc-470e-a43d-4ed46ff64325",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "Используется двуголовая модель."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e897e31c-07ac-497e-bab6-2470f87d8082",
   "metadata": {},
   "outputs": [],
   "source": [
    "models_path = Path(\"../models\")\n",
    "local_weights_filename = \"distilbert.pt\"\n",
    "load_local_weights = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dc3c04a-dca8-4046-86eb-86c4c211cc39",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoConfig, AutoModel, AutoTokenizer\n",
    "\n",
    "# \"sberbank-ai/ruRoberta-large\"\n",
    "# \"DeepPavlov/distilrubert-base-cased-conversational\"\n",
    "# \"xlm-roberta-base\"\n",
    "language_model_name = \"DeepPavlov/distilrubert-base-cased-conversational\"\n",
    "\n",
    "if load_local_weights:\n",
    "    language_model = AutoModel.from_config(AutoConfig.from_pretrained(language_model_name))\n",
    "else:\n",
    "    language_model = AutoModel.from_pretrained(language_model_name)\n",
    "tokenizer = AutoTokenizer.from_pretrained(language_model_name, model_max_length=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14136fd6-1828-4b43-9409-1132ab326d22",
   "metadata": {},
   "outputs": [],
   "source": [
    "for params in language_model.parameters():\n",
    "    params.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13b50aa5-cd30-45e0-b82c-99c102ba51e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.classifier import Classifier, evaluate_classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "050a4db5-7b3b-426f-9b7a-6b9021f1d524",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the classifier from the language model, tokenizer and label encoders, then move it to the target device.\n",
    "classifier = Classifier(language_model, tokenizer, preprocess.theme_to_group(data), executor_encoder, theme_encoder, hid_size=768).to(device)\n",
    "\n",
    "if load_local_weights:\n",
    "    # map_location remaps the stored tensors onto `device`, so a checkpoint saved\n",
    "    # on a GPU can also be restored on a CPU-only machine.\n",
    "    classifier.load_state_dict(torch.load(models_path / local_weights_filename, map_location=device))\n",
    "    classifier = classifier.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0d495c8-be6d-4f17-9083-36b8048f855f",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bfdac71-f2d1-4cfa-8f36-57dff682de21",
   "metadata": {},
   "outputs": [],
   "source": [
    "nllloss = torch.nn.NLLLoss()\n",
    "\n",
    "\n",
    "def classifier_loss(y_pred, y):\n",
    "    \"\"\"Weighted sum of two NLL losses: weight 0.2 for index 0 and 0.8 for index 1.\n",
    "\n",
    "    `y_pred` and `y` are indexable pairs (predictions and targets); element i of\n",
    "    `y_pred` is scored against element i of `y`.\n",
    "    \"\"\"\n",
    "    first_loss = nllloss(y_pred[0], y[0])\n",
    "    second_loss = nllloss(y_pred[1], y[1])\n",
    "    return 0.2 * first_loss + 0.8 * second_loss\n",
    "\n",
    "\n",
    "# Evaluate on the test loader before the training cells below are run.\n",
    "evaluate_classifier(classifier, test_dataloader, classifier_loss, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b62c66-60d0-485d-9d41-45ae138ee14d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0337a69-ac51-41ed-9840-c92433ebe34f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import f1_score\n",
    "from tqdm import tqdm\n",
    "\n",
    "def train_classifier(classifier, train_dataloader, test_dataloader, classifier_loss, device, lr=1e-6, n_epochs: int=5):\n",
    "    \"\"\"Train `classifier` with Adam and redraw the metric curves after every epoch.\n",
    "\n",
    "    Args:\n",
    "        classifier: model exposing `.tokenizer`; calling it on a tokenized batch\n",
    "            returns two predictions.\n",
    "        train_dataloader: yields `(message, executor, theme, group)` batches.\n",
    "        test_dataloader: handed to `evaluate_classifier` after each epoch.\n",
    "        classifier_loss: callable taking `(pred_1, pred_2)` and `(executor, theme)`\n",
    "            and returning a loss tensor.\n",
    "        device: device the tokenized batch is moved to.\n",
    "        lr: learning rate for Adam.\n",
    "        n_epochs: number of passes over `train_dataloader`.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping each metric name to the list of its per-epoch values.\n",
    "    \"\"\"\n",
    "    # The keys must match the ones appended to after evaluation; a mismatch here\n",
    "    # raises KeyError at the end of the very first epoch.\n",
    "    classifier_metrics = {\"train_loss\": [], \"test_loss\": [], \"test_f1_score_et\": [], \"test_f1_score_tg\": []}\n",
    "\n",
    "    opt = torch.optim.Adam(classifier.parameters(), lr=lr)\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        avg_loss = 0.0\n",
    "        total_samples = 0\n",
    "\n",
    "        for batch in tqdm(train_dataloader):\n",
    "            message, executor, theme, group = batch\n",
    "            del batch\n",
    "\n",
    "            batch_size = theme.shape[0]\n",
    "\n",
    "            tokens = classifier.tokenizer(message, padding=True, truncation='only_first',\n",
    "                                          return_tensors=\"pt\").to(device)\n",
    "            del message\n",
    "\n",
    "            y_pred_1, y_pred_2 = classifier(tokens)\n",
    "            _loss = classifier_loss((y_pred_1, y_pred_2), (executor, theme))\n",
    "            # Release local references that are no longer needed.\n",
    "            del y_pred_1\n",
    "            del y_pred_2\n",
    "            del executor\n",
    "            del theme\n",
    "\n",
    "            opt.zero_grad()\n",
    "            _loss.backward()\n",
    "            opt.step()\n",
    "\n",
    "            # Weight by batch size so the epoch average is per sample, not per batch.\n",
    "            avg_loss += _loss.item() * batch_size\n",
    "            total_samples += batch_size\n",
    "\n",
    "        avg_loss /= total_samples\n",
    "        classifier_metrics[\"train_loss\"].append(avg_loss)\n",
    "\n",
    "        test_loss, test_f1_score_et, test_f1_score_tg = evaluate_classifier(classifier, test_dataloader, classifier_loss, device)\n",
    "        classifier_metrics[\"test_loss\"].append(test_loss)\n",
    "        classifier_metrics[\"test_f1_score_et\"].append(test_f1_score_et)\n",
    "        classifier_metrics[\"test_f1_score_tg\"].append(test_f1_score_tg)\n",
    "\n",
    "        # Replace the previous epoch's figure with one subplot per metric.\n",
    "        clear_output(True)\n",
    "        plt.figure(figsize=(18,4))\n",
    "        for plot_index, (name, history) in enumerate(sorted(classifier_metrics.items())):\n",
    "            plt.subplot(1, len(classifier_metrics), plot_index + 1)\n",
    "            plt.title(name)\n",
    "            plt.plot(range(1, len(history) + 1), history)\n",
    "            plt.grid()\n",
    "\n",
    "        plt.show()\n",
    "\n",
    "    return classifier_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3064880-0a21-4092-94b6-98f3bf361199",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_classifier(classifier, train_dataloader, test_dataloader, classifier_loss, device, lr=1e-3, n_epochs=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f416ae6-e435-4e8d-a1cb-4dfeeb565b79",
   "metadata": {},
   "outputs": [],
   "source": [
    "for params in classifier.parameters():\n",
    "    params.requires_grad = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3301445b-35f9-4c1b-b069-cb35488e22da",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_classifier(classifier, train_dataloader, test_dataloader, classifier_loss, device, lr=1e-6, n_epochs=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cdf356f-294d-41e5-8e77-77f477d7cd2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate_classifier(classifier, test_dataloader, classifier_loss, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44b4fb34-d8ff-4196-bbbc-b7bf30e2c2f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier.predict(data[\"Текст инцидента\"][:10].tolist(), device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eedcfc9a-18fb-459f-ad35-e9e8d08270a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "theme_encoder.inverse_transform(data[\"Тема\"][:10].tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "281a4ef0-d54f-4b9d-8075-f09569892967",
   "metadata": {},
   "source": [
    "## Save model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d770833-fe29-46e8-a69d-330e710bf32a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the target directory exists; writing the checkpoint fails if the\n",
    "# parent folder is missing.\n",
    "models_path.mkdir(parents=True, exist_ok=True)\n",
    "torch.save(classifier.state_dict(), models_path / \"distilbert2.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "873824b8-764b-4aca-80b3-63ca2ce6685f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(models_path / \"executor_encoder.obj\", 'wb') as file:\n",
    "    pickle.dump(executor_encoder, file)\n",
    "\n",
    "with open(models_path / \"theme_encoder.obj\", 'wb') as file:\n",
    "    pickle.dump(theme_encoder, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c0be2f7-c616-458a-b9fb-793091c6de79",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
